import numpy as np
import paddle
from paddle.callbacks import ModelCheckpoint
from paddle.io import DataLoader
from YOLO.downstream import build_model
from DataPipe import build_dataset

def train(
    *,
    batch_size: int = 3,
    epochs: int = 3,
    weights_path: str = './model_weights/final.pdparams',
    save_dir: str = './model_weights/',
    shuffle_train: bool = False,
) -> None:
    """Run a training session and write checkpoints to ``save_dir``.

    Builds the 'train' and 'val' datasets with ``build_dataset``, wraps each
    in a ``DataLoader``, obtains the model from ``build_model`` and calls its
    ``fit`` method with a ``ModelCheckpoint`` callback.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so a bare ``train()`` call behaves as before.

    Args:
        batch_size: Batch size used for both the training and the
            validation ``DataLoader``.
        epochs: Number of epochs passed to ``model.fit``.
        weights_path: Path handed to ``build_model`` as ``path``.
            Presumably the weights file to start from -- confirm against
            ``YOLO.downstream.build_model``.
        save_dir: Directory given to ``ModelCheckpoint`` for saving
            checkpoints.
        shuffle_train: Whether the training ``DataLoader`` shuffles its
            samples. Defaults to ``False``, which matches the ``DataLoader``
            default the original code relied on; enable it explicitly if
            shuffled training batches are wanted.
    """
    vehicle_dataset = build_dataset(mode='train')
    vehicle_dataloader = DataLoader(
        vehicle_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
    )

    # NOTE(review): the 'val' split is passed positionally while the train
    # split uses ``mode=``; presumably both target the same parameter --
    # confirm against DataPipe.build_dataset before unifying the call style.
    val_dataloader = DataLoader(build_dataset('val'), batch_size=batch_size)

    # The returned object is expected to expose a ``fit`` method compatible
    # with the keyword arguments used below (looks like paddle.Model -- TODO
    # confirm in YOLO.downstream).
    model = build_model(path=weights_path)
    model.fit(
        train_data=vehicle_dataloader,
        eval_data=val_dataloader,
        epochs=epochs,
        verbose=1,
        callbacks=[ModelCheckpoint(save_dir=save_dir)],
    )

# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if '__main__' == __name__:
    train()